{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build Custom Layers & Modules\n",
    "\n",
    "Build custom layers and modules with TensorFlow v2.\n",
    "\n",
    "- Author: Aymeric Damien\n",
    "- Project: https://github.com/aymericdamien/TensorFlow-Examples/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import Model, layers\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST dataset parameters.\n",
    "num_classes = 10 # 0 to 9 digits\n",
    "num_features = 784 # 28*28\n",
    "\n",
    "# Training parameters.\n",
    "learning_rate = 0.01\n",
    "training_steps = 500\n",
    "batch_size = 256\n",
    "display_step = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare MNIST data.\n",
    "from tensorflow.keras.datasets import mnist\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "# Convert to float32.\n",
    "x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)\n",
    "# Flatten images to 1-D vector of 784 features (28*28).\n",
    "x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])\n",
    "# Normalize images value from [0, 255] to [0, 1].\n",
    "x_train, x_test = x_train / 255., x_test / 255."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use tf.data API to shuffle and batch data.\n",
    "train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a custom layer\n",
    "\n",
    "Build a custom layer with inner-variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a custom layer, extending TF 'Layer' class.\n",
    "# Layer compute: y = relu(W * x + b)\n",
    "class CustomLayer1(layers.Layer):\n",
    "    \n",
    "    # Layer arguments.\n",
    "    def __init__(self, num_units, **kwargs):\n",
    "        # Store the number of units (neurons).\n",
    "        self.num_units = num_units\n",
    "        super(CustomLayer1, self).__init__(**kwargs)\n",
    "        \n",
    "    def build(self, input_shape):\n",
    "        # Note: a custom layer can also include any other TF 'layers'.\n",
    "        shape = tf.TensorShape((input_shape[1], self.num_units))\n",
    "        # Create weight variables for this layer.\n",
    "        self.weight = self.add_weight(name='W',\n",
    "                                      shape=shape,\n",
    "                                      initializer=tf.initializers.RandomNormal,\n",
    "                                      trainable=True)\n",
    "        self.bias = self.add_weight(name='b',\n",
    "                                    shape=[self.num_units])\n",
    "        # Make sure to call the `build` method at the end\n",
    "        super(CustomLayer1, self).build(input_shape)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x = tf.matmul(inputs, self.weight)\n",
    "        x = x + self.bias\n",
    "        return tf.nn.relu(x)\n",
    "\n",
    "    def get_config(self):\n",
    "        base_config = super(CustomLayer1, self).get_config()\n",
    "        base_config['num_units'] = self.num_units\n",
    "        return base_config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create another custom layer\n",
    "\n",
    "Build another custom layer with inner TF 'layers'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a custom layer, extending TF 'Layer' class.\n",
    "# Custom layer: 2 Fully-Connected layers with residual connection.\n",
    "class CustomLayer2(layers.Layer):\n",
    "    \n",
    "    # Layer arguments.\n",
    "    def __init__(self, num_units, **kwargs):\n",
    "        self.num_units = num_units\n",
    "        super(CustomLayer2, self).__init__(**kwargs)\n",
    "        \n",
    "    def build(self, input_shape):\n",
    "        shape = tf.TensorShape((input_shape[1], self.num_units))\n",
    "        \n",
    "        self.inner_layer1 = layers.Dense(1)\n",
    "        self.inner_layer2 = layers.Dense(self.num_units)\n",
    "        \n",
    "        # Make sure to call the `build` method at the end\n",
    "        super(CustomLayer2, self).build(input_shape)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x = self.inner_layer1(inputs)\n",
    "        x = tf.nn.relu(x)\n",
    "        x = self.inner_layer2(x)\n",
    "        return x + inputs\n",
    "\n",
    "    def get_config(self):\n",
    "        base_config = super(CustomLayer2, self).get_config()\n",
    "        base_config['num_units'] = self.num_units\n",
    "        return base_config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create TF Model.\n",
    "class CustomNet(Model):\n",
    "    \n",
    "    def __init__(self):\n",
    "        super(CustomNet, self).__init__()\n",
    "        # Use custom layers created above.\n",
    "        self.layer1 = CustomLayer1(64)\n",
    "        self.layer2 = CustomLayer2(64)\n",
    "        self.out = layers.Dense(num_classes, activation=tf.nn.softmax)\n",
    "\n",
    "    # Set forward pass.\n",
    "    def __call__(self, x, is_training=False):\n",
    "        x = self.layer1(x)\n",
    "        x = tf.nn.relu(x)\n",
    "        x = self.layer2(x)\n",
    "        if not is_training:\n",
    "            # tf cross entropy expect logits without softmax, so only\n",
    "            # apply softmax when not training.\n",
    "            x = tf.nn.softmax(x)\n",
    "        return x\n",
    "\n",
    "# Build neural network model.\n",
    "custom_net = CustomNet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-Entropy loss function.\n",
    "def cross_entropy(y_pred, y_true):\n",
    "    y_true = tf.cast(y_true, tf.int64)\n",
    "    crossentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)\n",
    "    return tf.reduce_mean(crossentropy)\n",
    "\n",
    "# Accuracy metric.\n",
    "def accuracy(y_pred, y_true):\n",
    "    # Predicted class is the index of highest score in prediction vector (i.e. argmax).\n",
    "    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))\n",
    "    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "# Adam optimizer.\n",
    "optimizer = tf.optimizers.Adam(learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimization process. \n",
    "def run_optimization(x, y):\n",
    "    # Wrap computation inside a GradientTape for automatic differentiation.\n",
    "    with tf.GradientTape() as g:\n",
    "        pred = custom_net(x, is_training=True)\n",
    "        loss = cross_entropy(pred, y)\n",
    "\n",
    "        # Compute gradients.\n",
    "        gradients = g.gradient(loss, custom_net.trainable_variables)\n",
    "\n",
    "        # Update W and b following gradients.\n",
    "        optimizer.apply_gradients(zip(gradients, custom_net.trainable_variables))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 50, loss: 3.363096, accuracy: 0.902344\n",
      "step: 100, loss: 3.344931, accuracy: 0.910156\n",
      "step: 150, loss: 3.336300, accuracy: 0.914062\n",
      "step: 200, loss: 3.318396, accuracy: 0.925781\n",
      "step: 250, loss: 3.300045, accuracy: 0.937500\n",
      "step: 300, loss: 3.335487, accuracy: 0.898438\n",
      "step: 350, loss: 3.330979, accuracy: 0.914062\n",
      "step: 400, loss: 3.298509, accuracy: 0.921875\n",
      "step: 450, loss: 3.278253, accuracy: 0.953125\n",
      "step: 500, loss: 3.285335, accuracy: 0.945312\n"
     ]
    }
   ],
   "source": [
    "# Run training for the given number of steps.\n",
    "for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):\n",
    "    # Run the optimization to update W and b values.\n",
    "    run_optimization(batch_x, batch_y)\n",
    "    \n",
    "    if step % display_step == 0:\n",
    "        pred = custom_net(batch_x, is_training=False)\n",
    "        loss = cross_entropy(pred, batch_y)\n",
    "        acc = accuracy(pred, batch_y)\n",
    "        print(\"step: %i, loss: %f, accuracy: %f\" % (step, loss, acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
